[MLIR] Fix crash in RemoveDeadValues when terminator lacks RegionBranchTerminatorOpInterface#175300
[MLIR] Fix crash in RemoveDeadValues when terminator lacks RegionBranchTerminatorOpInterface#175300nataliakokoromyti wants to merge 1 commit intollvm:mainfrom
Conversation
…chTerminatorOpInterface The RemoveDeadValues pass was using cast<RegionBranchTerminatorOpInterface> which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes. This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface. Fixes llvm#174502
|
@llvm/pr-subscribers-mlir-core Author: None (nataliakokoromyti) ChangesThe RemoveDeadValues pass was using cast<RegionBranchTerminatorOpInterface> which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes. This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface. Fixes #174502 Full diff: https://github.com/llvm/llvm-project/pull/175300.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index fc2c2acf8afd3..4d20abb415229 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -477,8 +477,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
// TODO: this isn't correct in face of multiple terminators.
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
@@ -498,11 +500,17 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
+ auto terminatorIface =
terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
+ ? dyn_cast<RegionBranchTerminatorOpInterface>(terminator)
+ : nullptr;
+ // If terminator doesn't implement RegionBranchTerminatorOpInterface,
+ // we can't analyze it, so skip.
+ if (terminator && !terminatorIface)
+ return;
+ RegionBranchPoint point =
+ terminatorIface ? RegionBranchPoint(terminatorIface)
+ : RegionBranchPoint::parent();
for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
@@ -566,8 +574,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
|
|
@llvm/pr-subscribers-mlir Author: None (nataliakokoromyti) ChangesThe RemoveDeadValues pass was using cast<RegionBranchTerminatorOpInterface> which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes. This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface. Fixes #174502 Full diff: https://github.com/llvm/llvm-project/pull/175300.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index fc2c2acf8afd3..4d20abb415229 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -477,8 +477,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
// TODO: this isn't correct in face of multiple terminators.
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
@@ -498,11 +500,17 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
+ auto terminatorIface =
terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
+ ? dyn_cast<RegionBranchTerminatorOpInterface>(terminator)
+ : nullptr;
+ // If terminator doesn't implement RegionBranchTerminatorOpInterface,
+ // we can't analyze it, so skip.
+ if (terminator && !terminatorIface)
+ return;
+ RegionBranchPoint point =
+ terminatorIface ? RegionBranchPoint(terminatorIface)
+ : RegionBranchPoint::parent();
for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
@@ -566,8 +574,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
|
The RemoveDeadValues pass was using cast which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes.
This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface.
Fixes #174502